Lightning Module
A lightning module defines a system not just a model
「lightning moduleはモデルだけでなくシステムも定義する」-公式サイトより-
Lightning Moduleは大きく以下の3要素を持っている
深層モデル
optimizer, lr_scheduler
train / val / test ステップ + (loss計算)
通常は「modelクラスとoptimizerと他アレコレを用意して、オレオレforループで学習!」としている部分をすべて一つにまとめているイメージ。
例えば、Lossの計算方法が違ってそれによる出力数が違うときなどのtraining_stepが変化する場合
Lightning Moduleのクラスファイル自体を分けてしまうことができる
異なるクラスとしてファイル分けができるため見通しが良くなる
訓練(training_*)検証(validation_*)テスト(test_*)の各ループはそれぞれ3つのHookを持っている
*_step:モデルに通す作業を記述
*_step_end:モデルに通した後の処理を記述
*_epoch_end:エポックループが終わった後に実行されるもの
ex) training_*の場合 PyTorchで書くと以下の部分がまとめられている
code:pytorch.py
outs = []
for batch in data:
# ----- TRAINING STEP -----
out = training_step(batch) # Lightning Moduleで埋める必要のある関数
# PyTorch内では以下のように書いていた部分
# x, y = batch
# logits = model(x)
# loss = optimizer(logits, y)
# ----- TRAINING STEP -----
loss.backward()
optimizer.step()
optimizer.zero_grad()
training_epoch_end(outs) # Lightning Moduleで任意で埋めてよい関数
検証とテストのループも全く同じ構造を持っている。backwardとoptimizerの更新が無いだけ。
LightningModule内では以下のように書く。
勝手にstepでの返り値がリストにappendされていって*_step_endの引数になっている。
code:lightning.py
class LitModel(pl.LightningModule):
# --- 他の関数は省略
def training_step(self, batch, batch_idx):
"""batchが引数にある ⇒ これをmodelに通して ⇒ lossをreturn"""
loss = ...
return loss
def training_step_end(self, losses):
"""lossの重みを足し合わせたりなどモデル通した後のloss計算"""
# 特殊な例以外であまり使わなさそう
gpu_0_loss = losses0
gpu_1_loss = losses1
return (gpu_0_loss + gpu_1_loss) * 1/2
def training_epoch_end(...)
"""
1エポック終わった後の処理をする
(つまりバッチ単位でなく、エポック全体で適用したい処理をする)
複数のlossの平均やaccを計算して出力したいときに使用する
"""
# こちらは使う可能性が存在する
フレームワークだからoptimizerなど必要なものは先に宣言しておく必要がある。
code:lightning.py
class LitModel(pl.LightningModule):
# --- 他の関数は省略
def configure_optimizers(...)
"""
必須関数!
oprimizer schedulerをreturnするように記載する。辞書型でもよい。
"""
return optimizer, sceduler